import logging
import re

import numpy
import torch
import torch.nn.functional as F

import config
from .face_model import FaceModel
from .models import sifdnet
from .singleton import Singleton
from .transform import create_deepfake_detection_test_transform

LOG = logging.getLogger(__name__)


@Singleton
class DeepfakeDetectionSIFDNet(FaceModel):
    """Deepfake detector backed by the SIFDNet model.

    The network is constructed in ``__init__``. Its checkpoint weights are
    loaded lazily by :meth:`create_model`, which :meth:`get_predictions`
    triggers on first use when the ``DEEPFAKE_DETECTION`` config is ``'true'``.
    """

    # Single GPU device used for inference. The memory check, the model
    # placement and (indirectly) the input placement all refer to it.
    _CUDA_DEVICE = "cuda:0"
    # Minimum total memory (bytes) the GPU must report before the model is
    # moved onto it; below this, inference stays on the CPU.
    _MIN_GPU_MEMORY_BYTES = 1 * 1024 * 1024 * 1024

    def __init__(self):
        super().__init__()
        self.model = sifdnet()
        self.model_loaded = False
        # Test-time preprocessing built from the model's default config.
        self.transform = create_deepfake_detection_test_transform(model_cfg=self.model.default_cfg)

    def init_app(self, app):
        """No-op: this model needs no application-level initialisation."""
        pass

    def create_model(self):
        """Load the SIFDNet checkpoint and prepare the model for inference.

        The checkpoint path is read from the ``DEEPFAKE_DETECTION_SIFDNET_PATH``
        config key. The checkpoint may be a raw state dict or a dict holding
        one under ``"state_dict"``. A leading ``"module."`` is stripped from
        every key (presumably from a checkpoint saved under ``DataParallel``).
        The model is moved to the GPU only if one is available and reports at
        least ``_MIN_GPU_MEMORY_BYTES`` of total memory. It is then switched to
        eval mode and marked as loaded.
        """
        model_path = config.get_config('DEEPFAKE_DETECTION_SIFDNET_PATH')
        # Load onto the CPU first so a GPU-saved checkpoint can be read on a
        # CPU-only host. NOTE: torch.load unpickles the file, so the configured
        # path must point to a trusted checkpoint.
        checkpoint = torch.load(model_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        # The dot is escaped: an unescaped "." matches any character and
        # would also strip prefixes such as "modules".
        self.model.load_state_dict(
            {re.sub(r"^module\.", "", key): value for key, value in state_dict.items()},
            strict=True,
        )
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(device=self._CUDA_DEVICE).total_memory
            if gpu_memory >= self._MIN_GPU_MEMORY_BYTES:
                # Move to the same explicit device whose memory was checked;
                # a bare ``.cuda()`` would use the *current* CUDA device.
                self.model.to(self._CUDA_DEVICE)
        self.model.eval()
        self.model_loaded = True

    def get_embeddings(self, x: numpy.ndarray) -> numpy.ndarray:
        """Not implemented for this model; always returns ``None``."""
        return None

    def get_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
        """Not implemented for this model; always returns ``None``."""
        return None

    def get_predictions(self, x: numpy.ndarray, **kwargs) -> (float, numpy.ndarray):
        """Run deepfake detection on a single image.

        Args:
            x: input image passed to ``self.transform(image=x)``. The
                transform's ``"image"`` output must be a tensor, because it is
                unsqueezed and moved to the model's device. Presumably ``x``
                is an HxWxC image array (TODO: confirm with callers).
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            A ``(fake_prob, mask)`` tuple with two elements:

            - ``fake_prob``: the softmax probability of class index 1 for
              the single input image, as a Python float.
            - ``mask``: the model's mask output with size-1 dimensions
              squeezed, as a NumPy array on the CPU.
        """
        # NOTE(review): when DEEPFAKE_DETECTION is not 'true', the configured
        # checkpoint is never loaded and inference runs on the model exactly
        # as constructed by sifdnet(). Confirm callers gate on that flag.
        if not self.model_loaded and config.get_config('DEEPFAKE_DETECTION') == 'true':
            self.create_model()
        image = self.transform(image=x)["image"]
        # Add a batch dimension and place the input on whatever device the
        # model parameters live on (CPU or the GPU chosen in create_model).
        batch = image.unsqueeze(0).to(next(self.model.parameters()).device)
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            logits_pred, masks_pred = self.model(batch)
        probs_pred = F.softmax(logits_pred, dim=1)
        fake_prob = probs_pred[0, 1].item()
        mask = masks_pred.squeeze().cpu().numpy()
        return fake_prob, mask
